// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//


#pragma once

#include <vector>

#include <pollux/common/base/async_source.h>
#include <pollux/common/base/exceptions.h>
#include <pollux/common/base/portability.h>
#include <pollux/common/base/succinct_printer.h>
#include <pollux/common/future/pollux_promise.h>
#include <pollux/common/time/timer.h>

namespace kumo::pollux::memory {
    class MemoryPool;
    class ArbitrationOperation;

#define POLLUX_MEM_POOL_CAP_EXCEEDED(errorMessage)                   \
  _POLLUX_THROW(                                                     \
      ::kumo::pollux::PolluxRuntimeError,                         \
      ::kumo::pollux::error_source::kErrorSourceRuntime.c_str(), \
      ::kumo::pollux::error_code::kMemCapExceeded.c_str(),       \
      /* isRetriable */ true,                                       \
      "{}",                                                         \
      errorMessage);

#define POLLUX_MEM_ARBITRATION_FAILED(errorMessage)                   \
  _POLLUX_THROW(                                                      \
      ::kumo::pollux::PolluxRuntimeError,                          \
      ::kumo::pollux::error_source::kErrorSourceRuntime.c_str(),  \
      ::kumo::pollux::error_code::kMemArbitrationFailure.c_str(), \
      /* isRetriable */ true,                                        \
      "{}",                                                          \
      errorMessage);

#define POLLUX_MEM_POOL_ABORTED(errorMessage)                        \
  _POLLUX_THROW(                                                     \
      ::kumo::pollux::PolluxRuntimeError,                         \
      ::kumo::pollux::error_source::kErrorSourceRuntime.c_str(), \
      ::kumo::pollux::error_code::kMemAborted.c_str(),           \
      /* isRetriable */ true,                                       \
      "{}",                                                         \
      errorMessage);

    using MemoryArbitrationStateCheckCB = std::function<void(MemoryPool &)>;

    /// The memory arbitrator interface. There is one memory arbitrator object per
    /// memory manager which is responsible for arbitrating memory usage among the
    /// query memory pools for query memory isolation. When a memory pool exceeds
    /// its capacity limit, it sends a memory grow request to the memory manager
    /// which forwards the request to the memory arbitrator behind along with all
    /// the query memory pools as the arbitration candidates. The arbitrator tries
    /// to free up the overused memory from the selected candidates according to the
    /// supported query memory isolation policy. The memory arbitrator can free up
    /// memory by either reclaiming the used memory from a running query through
    /// techniques such as disk spilling or aborting a query. Different memory
    /// arbitrator implementations achieve different query memory isolation policy
    /// (see Kind definition below).
    class MemoryArbitrator {
    public:
        struct Config {
            /// The string kind of this memory arbitrator.
            ///
            /// NOTE: If kind is not set, a noop arbitrator is created which grants the
            /// maximum capacity to each newly created memory pool.
            std::string kind{};

            /// The total memory capacity in bytes of all the running queries.
            ///
            /// NOTE: this should be same capacity as we set in the associated memory
            /// manager.
            int64_t capacity;

            /// Callback to check if a memory arbitration request is issued from a
            /// driver thread, then the driver should be put in suspended state to avoid
            /// the potential deadlock when reclaim memory from the task of the request
            /// memory pool.
            MemoryArbitrationStateCheckCB arbitrationStateCheckCb{nullptr};

            /// Additional configs that are arbitrator implementation specific.
            melon::F14FastMap<std::string, std::string> extraConfigs{};

            std::string toString() const {
                std::stringstream ss;
                for (const auto &extraConfig: extraConfigs) {
                    ss << extraConfig.first << "=" << extraConfig.second << ";";
                }
                return fmt::format(
                    "kind={};capacity={};arbitrationStateCheckCb={};{}",
                    kind,
                    succinctBytes(capacity),
                    (arbitrationStateCheckCb ? "(set)" : "(unset)"),
                    ss.str());
            }
        };

        using Factory = std::function<std::unique_ptr<MemoryArbitrator>(
            const MemoryArbitrator::Config &config)>;

        /// Registers factory for a specific 'kind' of memory arbitrator
        /// MemoryArbitrator::Create looks up the registry to find the factory to
        /// create arbitrator instance based on the kind specified in arbitrator
        /// config.
        ///
        /// NOTE: we only allow the same 'kind' of memory arbitrator to be registered
        /// once. The function returns false if 'kind' is already registered.
        static bool registerFactory(const std::string &kind, Factory factory);

        /// Unregisters the registered factory for a specifc kind.
        ///
        /// NOTE: the function throws if the specified arbitrator 'kind' is not
        /// registered.
        static void unregisterFactory(const std::string &kind);

        /// Invoked by the memory manager to create an instance of memory arbitrator
        /// based on the kind specified in 'config'. The arbitrator kind must be
        /// registered through MemoryArbitrator::registerFactory(), otherwise the
        /// function throws a pollux user exception error.
        ///
        /// NOTE: if arbitrator kind is not set in 'config', then the function returns
        /// nullptr, and the memory arbitration function is disabled.
        static std::unique_ptr<MemoryArbitrator> create(const Config &config);

        virtual std::string kind() const = 0;

        uint64_t capacity() const {
            return config_.capacity;
        }

        const Config &config() const {
            return config_;
        }

        virtual ~MemoryArbitrator() = default;

        /// Invoked by the memory manager to shutdown the memory arbitrator to stop
        /// serving new memory arbitration requests.
        virtual void shutdown() = 0;

        /// Invoked by the memory manager to add a newly created memory pool. The
        /// memory arbitrator allocates the initial capacity for 'pool' and
        /// dynamically adjusts its capacity based query memory needs through memory
        /// arbitration.
        virtual void addPool(const std::shared_ptr<MemoryPool> &pool) = 0;

        /// Invoked by the memory manager to remove a destroyed memory pool. The
        /// memory arbitrator frees up all its capacity and stops memory arbitration
        /// operation on it.
        virtual void removePool(MemoryPool *pool) = 0;

        /// Invoked by the memory manager to grow a memory pool's capacity.
        /// 'pool' is the memory pool to request to grow. The memory arbitrator picks
        /// up a number of pools to either shrink its memory capacity without actually
        /// freeing memory or reclaim its used memory to free up enough memory for
        /// 'requestor' to grow.
        virtual void growCapacity(MemoryPool *pool, uint64_t requestBytes) = 0;

        /// Invoked by the memory manager to shrink up to 'targetBytes' free capacity
        /// from a memory 'pool', and returns them back to the arbitrator. If
        /// 'targetBytes' is zero, we shrink all the free capacity from the memory
        /// pool. The function returns the actual freed capacity from 'pool'.
        virtual uint64_t shrinkCapacity(MemoryPool *pool, uint64_t targetBytes) = 0;

        /// Invoked by the memory manager to globally shrink memory from
        /// memory pools by reclaiming only used memory, to reduce system memory
        /// pressure. The freed memory capacity is given back to the arbitrator.  If
        /// 'targetBytes' is zero, then try to reclaim all the memory from 'pools'.
        /// The function returns the actual freed memory capacity in bytes. If
        /// 'allowSpill' is true, it reclaims the used memory by spilling. If
        /// 'allowAbort' is true, it reclaims the used memory by aborting the queries
        /// with the most memory usage. If both are true, it first reclaims the used
        /// memory by spilling and then abort queries to reach the reclaim target.
        ///
        /// NOTE: The actual reclaimed used memory (hence system memory) may be less
        /// than 'targetBytes' due to the accounting of free capacity reclaimed. This
        /// is okay because when this method is called, system is normally under
        /// memory pressure, and there normally isn't much free capacity to reclaim.
        /// So the reclaimed used memory in this case should be very close to
        /// 'targetBytes' if enough used memory is reclaimable. We should improve this
        /// in the future.
        virtual uint64_t shrinkCapacity(
            uint64_t targetBytes,
            bool allowSpill = true,
            bool allowAbort = false) = 0;

        /// The internal execution stats of the memory arbitrator.
        struct Stats {
            /// The number of arbitration requests.
            uint64_t numRequests{0};
            /// The number of running arbitration requests.
            uint64_t numRunning{0};
            /// The number of succeeded arbitration requests.
            uint64_t numSucceeded{0};
            /// The number of aborted arbitration requests.
            uint64_t numAborted{0};
            /// The number of arbitration request failures.
            uint64_t numFailures{0};
            /// The number of reclaimed unused free bytes.
            uint64_t reclaimedFreeBytes{0};
            /// The number of reclaimed used bytes.
            uint64_t reclaimedUsedBytes{0};
            /// The max memory capacity in bytes.
            uint64_t maxCapacityBytes{0};
            /// The free memory capacity in bytes.
            uint64_t freeCapacityBytes{0};
            /// The free reserved memory capacity in bytes.
            uint64_t freeReservedCapacityBytes{0};
            /// The total number of times of the reclaim attempts that end up failing
            /// due to reclaiming at non-reclaimable stage.
            uint64_t numNonReclaimableAttempts{0};

            Stats(
                uint64_t _numRequests,
                uint64_t _numRunning,
                uint64_t _numSucceeded,
                uint64_t _numAborted,
                uint64_t _numFailures,
                uint64_t _reclaimedFreeBytes,
                uint64_t _reclaimedUsedBytes,
                uint64_t _maxCapacityBytes,
                uint64_t _freeCapacityBytes,
                uint64_t _freeReservedCapacityBytes,
                uint64_t _numNonReclaimableAttempts);

            Stats() = default;

            Stats operator-(const Stats &other) const;

            bool operator==(const Stats &other) const;

            bool operator!=(const Stats &other) const;

            bool operator<(const Stats &other) const;

            bool operator>(const Stats &other) const;

            bool operator>=(const Stats &other) const;

            bool operator<=(const Stats &other) const;

            bool empty() const {
                return numRequests == 0;
            }

            /// Returns the debug string of this stats.
            std::string toString() const;
        };

        virtual Stats stats() const = 0;

        /// Returns the debug string of this memory arbitrator.
        virtual std::string toString() const = 0;

    protected:
        explicit MemoryArbitrator(const Config &config) : config_(config) {
        }

        /// Helper utilities used by the memory arbitrator implementations to call
        /// protected methods of memory pool.
        static bool
        growPool(MemoryPool *pool, uint64_t growBytes, uint64_t reservationBytes);

        static uint64_t shrinkPool(MemoryPool *pool, uint64_t targetBytes);

        const Config config_;
    };

    /// Formatter for fmt.
    MELON_ALWAYS_INLINE std::string format_as(MemoryArbitrator::Stats stats) {
        return stats.toString();
    }

    MELON_ALWAYS_INLINE std::ostream &operator<<(
        std::ostream &o,
        const MemoryArbitrator::Stats &stats) {
        return o << stats.toString();
    }

    /// The memory reclaimer interface is used by memory pool to participate in
    /// the memory arbitration execution (enter/leave arbitration process) as well
    /// as reclaim memory from the associated query object. We have default
    /// implementation that always reclaim memory from the child memory pool with
    /// most reclaimable memory. This is used by the query and plan node memory
    /// pools which don't need customized operations. A task memory pool needs to
    /// to pause a task execution before reclaiming memory from its child pools.
    /// This avoids any potential race condition between concurrent memory
    /// reclamation operation and the task activities. An operator memory pool needs
    /// to to put the the associated task driver thread into suspension state before
    /// entering into an arbitration process. It is because the memory arbitrator
    /// needs to pause a task execution before reclaim memory from the task. It is
    /// possible that the memory arbitration tries to reclaim memory from the task
    /// which initiates the memory arbitration request. If we don't put the driver
    /// thread into suspension state, then the memory arbitration process might
    /// run into deadlock as the task will never be paused. The operator memory pool
    /// also needs to reclaim the actually used memory from the associated operator
    /// through techniques such as disks spilling.
    class MemoryReclaimer {
    public:
        /// Used to collect memory reclaim execution stats.
        struct Stats {
            /// The total number of times of the reclaim attempts that end up failing
            /// due to reclaiming at non-reclaimable stage.
            uint64_t numNonReclaimableAttempts{0};

            /// The total time to do the reclaim in microseconds.
            uint64_t reclaimExecTimeUs{0};

            /// The total reclaimed memory bytes.
            uint64_t reclaimedBytes{0};

            /// The total time of task pause during reclaim in microseconds.
            uint64_t reclaimWaitTimeUs{0};

            void reset();

            bool operator==(const Stats &other) const;

            bool operator!=(const Stats &other) const;

            Stats &operator+=(const Stats &other);
        };

        virtual ~MemoryReclaimer() = default;

        static std::unique_ptr<MemoryReclaimer> create(int32_t priority = 0);

        /// Invoked memory reclaim function from 'pool' and record execution 'stats'.
        static uint64_t run(const std::function<int64_t()> &func, Stats &stats);

        /// Invoked by the memory arbitrator before entering the memory arbitration
        /// processing. The default implementation does nothing but user can override
        /// this if needs. For example, an operator memory reclaimer needed to put the
        /// associated driver execution into suspension state so that the task which
        /// initiates the memory arbitration request, can be paused for memory
        /// reclamation during the memory arbitration processing.
        virtual void enterArbitration() {
        }

        /// Invoked by the memory arbitrator after finishes the memory arbitration
        /// processing and is used in pair with 'enterArbitration'. For example, an
        /// operator memory pool needs to moves the associated driver execution out of
        /// the suspension state.
        ///
        /// NOTE: it is guaranteed to be called also on failure path if
        /// enterArbitration has been called.
        virtual void leaveArbitration() noexcept {
        }

        /// Invoked by upper layer reclaimer, to return the priority of this
        /// reclaimer. The priority determines the reclaiming order of self among all
        /// same level reclaimers. The smaller the number, the higher the priority.
        /// Consider the following memory pool & reclaimer structure:
        ///
        ///                 rec1(pri 1)
        ///                /            \
        ///               /              \
        ///              /                \
        ///      rec2(pri 1)           rec3(pri 3)
        ///      /        \             /        \
        ///     /          \           /          \
        /// rec4(pri 1) rec5(pri 0)  rec6(pri 0) rec7(pri 1)
        ///
        /// The reclaiming traversing order will be rec1 -> rec2 -> rec5 -> rec4 ->
        /// rec3 -> rec6 -> rec7
        virtual int32_t priority() const {
            return priority_;
        };

        /// Invoked by the memory arbitrator to get the amount of memory bytes that
        /// can be reclaimed from 'pool'. The function returns true if 'pool' is
        /// reclaimable and returns the estimated reclaimable bytes in
        /// 'reclaimableBytes'.
        virtual bool reclaimableBytes(
            const MemoryPool &pool,
            uint64_t &reclaimableBytes) const;

        /// Invoked by the memory arbitrator to reclaim from memory 'pool' with
        /// specified 'targetBytes'. It is expected to reclaim at least that amount of
        /// memory bytes but there is no guarantees. If 'targetBytes' is zero, then it
        /// reclaims all the reclaimable memory from the memory 'pool'. 'maxWaitMs'
        /// specifies the max time to wait for reclaim if not zero. The memory
        /// reclaim might fail if exceeds the timeout. The function returns the actual
        /// reclaimed memory bytes.
        ///
        /// NOTE: 'maxWaitMs' is optional and the actual memory reclaim implementation
        /// can choose to respect this timeout or not on its own.
        virtual uint64_t reclaim(
            MemoryPool *pool,
            uint64_t targetBytes,
            uint64_t maxWaitMs,
            Stats &stats);

        /// Invoked by the memory arbitrator to abort memory 'pool' and the associated
        /// query execution when encounters non-recoverable memory reclaim error or
        /// fails to reclaim enough free capacity. The abort is a synchronous
        /// operation and we expect most of used memory to be freed after the abort
        /// completes. 'error' should be passed in as the direct cause of the
        /// abortion. It will be propagated all the way to task level for accurate
        /// error exposure.
        virtual void abort(MemoryPool *pool, const std::exception_ptr &error);

    protected:
        explicit MemoryReclaimer(int32_t priority) : priority_(priority) {
        };

    private:
        const int32_t priority_;
    };

    /// Helper class used to measure the memory bytes reclaimed from a memory pool
    /// by a memory reclaim function.
    class ScopedReclaimedBytesRecorder {
    public:
        ScopedReclaimedBytesRecorder(MemoryPool *pool, int64_t *reclaimedBytes);

        ~ScopedReclaimedBytesRecorder();

    private:
        MemoryPool *const pool_;
        int64_t *const reclaimedBytes_;
        const int64_t reservedBytesBeforeReclaim_;
    };

    /// The object is used to set/clear non-reclaimable section of an operation in
    /// the middle of its execution. It allows the memory arbitrator to reclaim
    /// memory from a running operator which is waiting for memory arbitration.
    /// 'nonReclaimableSection' points to the corresponding flag of the associated
    /// operator.
    class ReclaimableSectionGuard {
    public:
        explicit ReclaimableSectionGuard(tsan_atomic<bool> *nonReclaimableSection)
            : nonReclaimableSection_(nonReclaimableSection),
              oldNonReclaimableSectionValue_(*nonReclaimableSection_) {
            *nonReclaimableSection_ = false;
        }

        ~ReclaimableSectionGuard() {
            *nonReclaimableSection_ = oldNonReclaimableSectionValue_;
        }

    private:
        tsan_atomic<bool> *const nonReclaimableSection_;
        const bool oldNonReclaimableSectionValue_;
    };

    class NonReclaimableSectionGuard {
    public:
        explicit NonReclaimableSectionGuard(tsan_atomic<bool> *nonReclaimableSection)
            : nonReclaimableSection_(nonReclaimableSection),
              oldNonReclaimableSectionValue_(*nonReclaimableSection_) {
            *nonReclaimableSection_ = true;
        }

        ~NonReclaimableSectionGuard() {
            *nonReclaimableSection_ = oldNonReclaimableSectionValue_;
        }

    private:
        tsan_atomic<bool> *const nonReclaimableSection_;
        const bool oldNonReclaimableSectionValue_;
    };

    /// The memory arbitration context which is set as per-thread local variable by
    /// memory arbitrator. It is used to indicate if a running thread is under
    /// memory arbitration. This helps to enable sanity check such as all the memory
    /// reservations during memory arbitration should come from the spilling memory
    /// pool.
    struct MemoryArbitrationContext {
        /// Defines the type of memory arbitration.
        enum class Type {
            /// Indicates the memory arbitration is triggered by a memory pool for its
            /// own capacity growth.
            kLocal,
            /// Indicates the memory arbitration is triggered by the memory arbitrator
            /// to free up memory for the system.
            kGlobal,
        };

        static std::string typeName(Type type);

        const Type type;
        /// The name of the request memory pool for local arbitration. It is empty for
        /// global memory arbitration type.
        const std::string requestorName;

        explicit MemoryArbitrationContext(const MemoryPool *requestor);

        MemoryArbitrationContext() : type(Type::kGlobal) {
        }
    };

    /// Object used to set/restore the memory arbitration context when a thread is
    /// under memory arbitration processing.
    class ScopedMemoryArbitrationContext {
    public:
        ScopedMemoryArbitrationContext();

        explicit ScopedMemoryArbitrationContext(
            const MemoryArbitrationContext *context);

        explicit ScopedMemoryArbitrationContext(const MemoryPool *requestor);

        ~ScopedMemoryArbitrationContext();

    private:
        MemoryArbitrationContext *const savedArbitrationCtx_{nullptr};
        MemoryArbitrationContext currentArbitrationCtx_;
    };

    /// Object used to setup arbitration section for a memory pool.
    class MemoryPoolArbitrationSection {
    public:
        explicit MemoryPoolArbitrationSection(MemoryPool *pool);

        ~MemoryPoolArbitrationSection();

    private:
        MemoryPool *const pool_;
    };

    /// Returns the memory arbitration context set by a per-thread local variable if
    /// the running thread is under memory arbitration processing.
    const MemoryArbitrationContext *memoryArbitrationContext();

    /// Returns true if the running thread is under memory arbitration or not.
    bool underMemoryArbitration();

    /// Creates an async memory reclaim task with memory arbitration context set.
    /// This is to avoid recursive memory arbitration during memory reclaim.
    ///
    /// NOTE: this must be called under memory arbitration.
    template<typename Item>
    std::shared_ptr<AsyncSource<Item> > createAsyncMemoryReclaimTask(
        std::function<std::unique_ptr<Item>()> task) {
        auto *arbitrationCtx = memory::memoryArbitrationContext();
        return std::make_shared<AsyncSource<Item> >(
            [asyncTask = std::move(task), arbitrationCtx]() -> std::unique_ptr<Item> {
                std::unique_ptr<ScopedMemoryArbitrationContext> restoreArbitrationCtx;
                if (arbitrationCtx != nullptr) {
                    restoreArbitrationCtx =
                            std::make_unique<ScopedMemoryArbitrationContext>(arbitrationCtx);
                }
                return asyncTask();
            });
    }

    /// The function triggers memory arbitration by shrinking memory pools from
    /// 'manager' by invoking shrinkPools API. If 'manager' is not set, then it
    /// shrinks from the process wide memory manager. If 'targetBytes' is zero, then
    /// it reclaims all the memory from 'manager' if possible. If 'allowSpill' is
    /// true, then it allows to reclaim the used memory by spilling.
    class MemoryManager;

    void testingRunArbitration(
        uint64_t targetBytes = 0,
        bool allowSpill = true,
        MemoryManager *manager = nullptr);

    /// The function triggers memory arbitration by shrinking memory pools from
    /// 'manager' of 'pool' by invoking its shrinkPools API. If 'targetBytes' is
    /// zero, then it reclaims all the memory from 'manager' if possible. If
    /// 'allowSpill' is true, then it allows to reclaim the used memory by spilling.
    void testingRunArbitration(
        MemoryPool *pool,
        uint64_t targetBytes = 0,
        bool allowSpill = true);
} // namespace kumo::pollux::memory

#if FMT_VERSION < 100100
template <>
struct fmt::formatter<kumo::pollux::memory::MemoryArbitrator::Stats>
    : formatter<std::string> {
  auto format(
      kumo::pollux::memory::MemoryArbitrator::Stats s,
      format_context& ctx) const {
    return formatter<std::string>::format(s.toString(), ctx);
  }
};
#endif
